def calculate_heterogeneity_score(
    client_distribution, global_distribution, num_classes
):
    """Score how far a client's class distribution deviates from the global one.

    For every class index in ``range(num_classes)`` the client/global
    probability ratio is compared with 1 (perfect agreement). The absolute
    deviations are summed, averaged over ``num_classes``, and capped at 1.

    Args:
        client_distribution: Mapping of class index -> probability for one
            client. Missing classes are treated as probability 0.
        global_distribution: Mapping of class index -> probability for the
            whole population. Missing classes default to the uniform value
            ``1 / num_classes``.
        num_classes: Number of class indices (0 .. num_classes - 1) to score.

    Returns:
        A float in ``[0.0, 1.0]``. 0.0 means the distributions agree on every
        scored class. Returns 0.0 when ``num_classes <= 0`` (nothing to
        compare) instead of dividing by zero.

    Note:
        Classes with zero probability on either side are skipped entirely.
        A class the client lacks therefore adds nothing to the score. The
        resulting concentration still shows up as large ratios on the
        classes the client does have.
    """
    if num_classes <= 0:
        # Previously this raised ZeroDivisionError (or returned -0.0 for
        # negative values); an empty class set has no heterogeneity.
        return 0.0

    uniform_prob = 1 / num_classes
    score = 0.0
    for cls in range(num_classes):
        client_prob = client_distribution.get(cls, 0)
        global_prob = global_distribution.get(cls, uniform_prob)
        # Skip classes with no mass on either side; this also guards the
        # division below against a zero global probability.
        if client_prob > 0 and global_prob > 0:
            score += abs(client_prob / global_prob - 1)

    # Cap with a float so the return type is consistently float.
    return min(score / num_classes, 1.0)


def calculate_global_distribution(
    client_class_distributions, num_classes, min_prob=0.01
):
    """Approximate the global class distribution from per-client distributions.

    Each class probability is the unweighted mean of that class's probability
    across clients. Client sample counts are not taken into account, hence
    "approximate". Means below ``min_prob`` are raised to ``min_prob`` so
    that ratios against this distribution never divide by zero. The result
    is then renormalized to sum to 1.

    Args:
        client_class_distributions: Sized sequence of mappings, each mapping
            class index -> probability for one client. Missing classes count
            as 0.
        num_classes: Number of class indices (0 .. num_classes - 1) to produce.
        min_prob: Floor applied to each averaged probability before
            renormalization. Defaults to 0.01, the previously hard-coded value.

    Returns:
        Dict mapping each class index in ``range(num_classes)`` to a
        probability. The values sum to 1 when ``num_classes > 0``, and the
        dict is empty otherwise. With no clients, or when every floored value
        is 0, a uniform distribution is returned instead of raising
        ZeroDivisionError.
    """
    n_clients = len(client_class_distributions)

    global_dist = {}
    for cls in range(num_classes):
        total_prob = sum(dist.get(cls, 0) for dist in client_class_distributions)
        # An empty client list contributes nothing rather than dividing by 0.
        mean_prob = total_prob / n_clients if n_clients else 0.0
        # Floor tiny/zero probabilities to avoid zero entries downstream.
        global_dist[cls] = max(mean_prob, min_prob)

    total = sum(global_dist.values())
    if total <= 0:
        # Nothing to normalize (e.g. min_prob == 0 with all-zero input);
        # fall back to uniform. Also yields {} when num_classes <= 0.
        return {cls: 1 / num_classes for cls in range(num_classes)}

    return {cls: prob / total for cls, prob in global_dist.items()}
